from numpy import dtype
from typing import Callable
import torch
from torch import nn
from pytorch3d.renderer import ImplicitRenderer
from typing import NamedTuple


class LebesgueRaymarcher(nn.Module):
    """
    Raymarcher that renders a ray as a sum of per-interval values weighted by
    the absolute width of the corresponding grid interval.

    The model feeding this marcher outputs opacities rather than densities, so
    no emission-absorption compositing is performed here.
    """

    def forward(self, rays_features, ray_bundle=None, nerf_opacity=None, max_depth=1, **kwargs):
        """
        Args:
            rays_features: mapping with the keys
                `'grid'`: tensor whose second-to-last axis indexes the grid
                    nodes along a ray (N nodes give N - 1 intervals);
                `'avg_values'`: tensor that broadcasts against the N - 1
                    interval widths along the same axis;
                `'opacity'` (optional): tensor concatenated to the rendered
                    values along the last axis.
            ray_bundle: forwarded to `nerf_opacity._prepare_kwargs_from_bundle`;
                only used when `nerf_opacity` is given and `'opacity'` is absent.
            nerf_opacity: optional object providing
                `_prepare_kwargs_from_bundle(ray_bundle)` and
                `opacity(depth, **kwargs)`; queried at depth `max_depth`.
            max_depth: depth at which `nerf_opacity.opacity` is evaluated.
            **kwargs: ignored. The renderer forwards one set of keyword
                arguments to the sampler, the volumetric function and the
                raymarcher, so arguments meant for the others must not raise.

        Returns:
            The summed values, with opacities appended along the last axis
            when they are available from `rays_features` or `nerf_opacity`.
        """
        grid = rays_features['grid']
        # abs() keeps the widths non-negative even if the grid is not
        # monotonically increasing along the ray.
        widths = (grid[..., 1:, :] - grid[..., :-1, :]).abs()
        images = (widths * rays_features['avg_values']).sum(-2)

        if 'opacity' in rays_features:
            # A precomputed opacity takes precedence over `nerf_opacity`.
            opacities = rays_features['opacity']
        elif nerf_opacity is not None:
            opacity_kwargs = nerf_opacity._prepare_kwargs_from_bundle(ray_bundle)
            depth = max_depth * torch.ones(1, 1, dtype=images.dtype, device=images.device)
            opacities = nerf_opacity.opacity(depth, **opacity_kwargs)
            # Match the leading (batch / ray) dimensions of the rendered values.
            opacities = opacities.view(*images.shape[:-1], -1)
        else:
            return images
        return torch.cat((images, opacities), dim=-1)


class AvgRaymarcher(nn.Module):
    """
    Raymarcher that renders a ray by summing the per-interval values in
    `rays_features['avg_values']` along the second-to-last axis.

    The model feeding this marcher outputs opacities rather than densities, so
    no emission-absorption compositing is performed here.
    """

    def forward(self, rays_features, ray_bundle=None, nerf_opacity=None, max_depth=1, **kwargs):
        """
        Args:
            rays_features: mapping with the keys
                `'avg_values'`: tensor summed over its second-to-last axis;
                `'opacity'` (optional): tensor concatenated to the rendered
                    values along the last axis.
            ray_bundle: forwarded to `nerf_opacity._prepare_kwargs_from_bundle`;
                only used when `nerf_opacity` is given and `'opacity'` is absent.
            nerf_opacity: optional object providing
                `_prepare_kwargs_from_bundle(ray_bundle)` and
                `opacity(depth, **kwargs)`; queried at depth `max_depth`.
            max_depth: depth at which `nerf_opacity.opacity` is evaluated.
            **kwargs: ignored. The renderer forwards one set of keyword
                arguments to the sampler, the volumetric function and the
                raymarcher, so arguments meant for the others must not raise.

        Returns:
            The summed values, with opacities appended along the last axis
            when they are available from `rays_features` or `nerf_opacity`.
        """
        images = rays_features['avg_values'].sum(-2)

        if 'opacity' in rays_features:
            # A precomputed opacity takes precedence over `nerf_opacity`.
            opacities = rays_features['opacity']
        elif nerf_opacity is not None:
            opacity_kwargs = nerf_opacity._prepare_kwargs_from_bundle(ray_bundle)
            depth = max_depth * torch.ones(1, 1, dtype=images.dtype, device=images.device)
            opacities = nerf_opacity.opacity(depth, **opacity_kwargs)
            # Match the leading (batch / ray) dimensions of the rendered values.
            opacities = opacities.view(*images.shape[:-1], -1)
        else:
            return images
        return torch.cat((images, opacities), dim=-1)


# Not-yet-released code from pytorch3d:
def _jiggle_within_stratas(bin_centers: torch.Tensor, single_sample_per_bin: bool = True, eps=0.0) -> torch.Tensor:
    """
    Draws one uniformly random location inside each stratum defined by
    `bin_centers`.

    Two stratification modes are supported:

    * `single_sample_per_bin=True`: the strata are the intervals between
      consecutive input points, `[z_i, z_{i+1}]`, so N points produce
      N - 1 samples.
    * `single_sample_per_bin=False`: each point `z` is replaced by a sample
      from `[z - delta_-, z + delta_+]`, where `delta_-` is half of the
      distance to the previous point and `delta_+` is half of the distance
      to the next point. For the first and last points the corresponding
      boundary delta is zero. N points produce N samples.

    Args:
        `bin_centers`: The input points of size (..., N); sampling is applied
            independently along the last dimension. Each row should be
            sorted in ascending order.
        `single_sample_per_bin`: Selects the stratification mode described
            above.
        `eps`: Margin kept from the lower end of every stratum: a sample is
            `lower + eps + u * (upper - lower - eps)` with `u` drawn by
            `torch.rand_like`.
    Returns:
        a tensor of size (..., N - 1) if `single_sample_per_bin` is set,
        otherwise (..., N), with the locations jiggled within stratas/bins.
    """
    if not single_sample_per_bin:
        # Strata boundaries are the midpoints between neighbouring centers;
        # the outermost centers act as the outer boundaries.
        midpoints = 0.5 * (bin_centers[..., 1:] + bin_centers[..., :-1])
        first_center = bin_centers[..., :1]
        last_center = bin_centers[..., -1:]
        lower = torch.cat((first_center, midpoints), dim=-1)
        upper = torch.cat((midpoints, last_center), dim=-1)
    else:
        # Strata are the gaps between consecutive input points.
        lower, upper = bin_centers[..., :-1], bin_centers[..., 1:]

    # Uniform draw within each stratum, shifted away from its lower end by eps.
    usable_width = upper - lower - eps
    fraction = torch.rand_like(lower)
    return lower + usable_width * fraction + eps


class ImplicitRendererDict(ImplicitRenderer):
    """
    Variant of `ImplicitRenderer` whose volumetric function returns a single
    object (named `rays_features_dict` here) that is handed to the raymarcher
    as `rays_features`.

    The ray bundle produced by the raysampler is repacked into this module's
    `RayBundle`, which additionally carries the sampler's untouched lengths in
    `original_lengths`. When `stratified_resamling` is enabled, `lengths` is
    replaced by one random sample per interval between consecutive sampler
    lengths.
    """

    def __init__(
        self,
        raysampler: Callable,
        raymarcher: Callable,
        stratified_resamling: bool = False,
        eps: float = 0.0,
    ) -> None:
        """
        Args:
            raysampler: callable producing the ray bundle; passed to the
                parent constructor.
            raymarcher: callable turning ray features into renders; passed to
                the parent constructor.
            stratified_resamling: if True, jiggle the sampled lengths within
                the intervals between consecutive lengths.
            eps: margin forwarded to `_jiggle_within_stratas`.
        """
        super().__init__(raysampler, raymarcher)
        self.eps = eps
        self.stratified_resamling = stratified_resamling

    def forward(self, cameras, volumetric_function, **kwargs):
        """
        Samples rays, evaluates `volumetric_function` on them and marches
        along the rays.

        Args:
            cameras: forwarded to the raysampler and the volumetric function.
            volumetric_function: callable invoked as
                `volumetric_function(ray_bundle=..., cameras=..., **kwargs)`.
            **kwargs: forwarded unchanged to the raysampler, the volumetric
                function and the raymarcher.

        Returns:
            A tuple `(images, ray_bundle, rays_features_dict)`: the raymarcher
            output, the `RayBundle` used for evaluation, and the raw output
            of the volumetric function.

        Raises:
            ValueError: if `volumetric_function` is not callable.
        """
        if not callable(volumetric_function):
            raise ValueError('"volumetric_function" has to be a "Callable" object.')

        # Step 1: the ray sampler parametrizes the rendering rays.
        sampled = self.raysampler(
            cameras=cameras, volumetric_function=volumetric_function, **kwargs
        )

        # Step 2: choose the lengths at which the rays are evaluated; the
        # sampler's own lengths are always preserved alongside them.
        sampler_lengths = sampled.lengths
        if self.stratified_resamling:
            eval_lengths = _jiggle_within_stratas(
                sampler_lengths, single_sample_per_bin=True, eps=self.eps
            )
        else:
            eval_lengths = sampler_lengths

        ray_bundle = RayBundle(
            origins=sampled.origins,
            directions=sampled.directions,
            lengths=eval_lengths,
            original_lengths=sampler_lengths,
            xys=sampled.xys,
        )

        # Step 3: evaluate the volumetric function at the ray points.
        rays_features_dict = volumetric_function(
            ray_bundle=ray_bundle, cameras=cameras, **kwargs
        )

        # Step 4: march along the rays to obtain the renders.
        images = self.raymarcher(
            rays_features=rays_features_dict, ray_bundle=ray_bundle, **kwargs
        )

        return images, ray_bundle, rays_features_dict

class RayBundle(NamedTuple):
    """
    RayBundle parametrizes points along projection rays by storing ray `origins`,
    `directions` vectors and `lengths` at which the ray-points are sampled.
    Furthermore, the xy-locations (`xys`) of the ray pixels are stored as well.
    Note that `directions` don't have to be normalized; they define unit vectors
    in the respective 1D coordinate systems; see documentation for
    :func:`ray_bundle_to_ray_points` for the conversion formula.

    Compared with the fields listed above, this tuple carries one extra field,
    `original_lengths`: `ImplicitRendererDict.forward` stores the lengths
    returned by the raysampler there, while `lengths` may hold the randomly
    jiggled samples when stratified resampling is enabled (otherwise both
    fields reference the same tensor).
    """

    # Ray origins (renderer comments describe it as minibatch x ... x 3 -- TODO confirm).
    origins: torch.Tensor
    # Ray directions, not necessarily normalized (minibatch x ... x 3 per renderer comments -- TODO confirm).
    directions: torch.Tensor
    # Lengths along each ray at which points are evaluated; possibly jiggled.
    lengths: torch.Tensor
    # Lengths exactly as produced by the raysampler, before any jiggling.
    original_lengths: torch.Tensor
    # xy-locations of the ray pixels (minibatch x ... x 2 per renderer comments -- TODO confirm).
    xys: torch.Tensor